<CourseFloatingBanner chapter={2}
  classNames="absolute z-10 right-0 top-0"
  notebooks={[
    {label: "Google Colab", value: "https://colab.research.google.com/github/huggingface/notebooks/blob/main/course/en/chapter11/section3.ipynb"},
]} />

# Supervised Fine-Tuning

Supervised Fine-Tuning (SFT) is a process primarily used to adapt pre-trained language models to follow instructions, engage in dialogue, and use specific output formats. While pre-trained models have impressive general capabilities, SFT helps transform them into assistant-like models that can better understand and respond to user prompts. This is typically done by training on datasets of human-written conversations and instructions.

This page provides a step-by-step guide to fine-tuning the [`HuggingFaceTB/SmolLM2-135M`](https://huggingface.co/HuggingFaceTB/SmolLM2-135M) model using the [`SFTTrainer`](https://huggingface.co/docs/trl/en/sft_trainer). By following these steps, you can adapt the model to perform specific tasks more effectively.

## When to Use SFT

Before diving into implementation, it's important to understand when SFT is the right choice for your project. As a first step, you should consider whether using an existing instruction-tuned model with well-crafted prompts would suffice for your use case. SFT involves significant computational resources and engineering effort, so it should only be pursued when prompting existing models proves insufficient.

<Tip>
Consider SFT only if you:
- Need additional performance beyond what prompting can achieve
- Have a specific use case where the cost of using a large general-purpose model outweighs the cost of fine-tuning a smaller model
- Require specialized output formats or domain-specific knowledge that existing models struggle with
</Tip>

If you determine that SFT is necessary, the decision to proceed depends on two primary factors:

### Template Control
SFT allows precise control over the model's output structure. This is particularly valuable when you need the model to:
1. Generate responses in a specific chat template format
2. Follow strict output schemas
3. Maintain consistent styling across responses

### Domain Adaptation
When working in specialized domains, SFT helps align the model with domain-specific requirements by:
1. Teaching domain terminology and concepts
2. Enforcing professional standards
3. Handling technical queries appropriately
4. Following industry-specific guidelines

<Tip>
Before starting SFT, evaluate whether your use case requires:
- Precise output formatting
- Domain-specific knowledge
- Consistent response patterns
- Adherence to specific guidelines

This evaluation will help determine if SFT is the right approach for your needs.
</Tip>

## Dataset Preparation

The supervised fine-tuning process requires a task-specific dataset structured as input-output pairs. Each example should include:
1. An input prompt
2. The expected model response
3. Any additional context or metadata

The quality of your training data is crucial for successful fine-tuning. Let's look at how to prepare and validate your dataset:

<iframe
  src="https://huggingface.co/datasets/HuggingFaceTB/smoltalk/embed/viewer/all/train?row=0"
  frameborder="0"
  width="100%"
  height="360px"
></iframe>
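
Before training, it can help to sanity-check that each record follows the expected conversational structure. The sketch below assumes each record has a `messages` field holding a list of `{"role", "content"}` dictionaries, as in the dataset shown above. The `validate_example` helper and the set of allowed roles are illustrative assumptions, not part of TRL or `datasets`.

```python
ALLOWED_ROLES = {"system", "user", "assistant"}  # assumed role names


def validate_example(example):
    """Return a list of problems found in one conversational record."""
    messages = example.get("messages")
    if not isinstance(messages, list) or not messages:
        return ["missing or empty 'messages' list"]

    problems = []
    for i, message in enumerate(messages):
        role = message.get("role")
        content = message.get("content")
        if role not in ALLOWED_ROLES:
            problems.append(f"message {i}: unexpected role {role!r}")
        if not isinstance(content, str) or not content.strip():
            problems.append(f"message {i}: empty content")

    # For SFT we usually want the conversation to end with the target response.
    if messages[-1].get("role") != "assistant":
        problems.append("conversation does not end with an assistant turn")
    return problems


good = {"messages": [
    {"role": "user", "content": "What is SFT?"},
    {"role": "assistant", "content": "Supervised fine-tuning on labeled examples."},
]}
bad = {"messages": [{"role": "user", "content": ""}]}

print(validate_example(good))  # []
print(validate_example(bad))   # two problems: empty content, no assistant turn
```

A check like this could be run over a sample of the dataset to estimate how many records need cleaning before training starts.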

## Training Configuration

The success of your fine-tuning depends heavily on choosing the right training parameters. The `SFTTrainer` configuration exposes several groups of parameters that control the training process. Let's explore each group and its purpose:

1. **Training Duration Parameters**:
   - `num_train_epochs`: Controls total training duration
   - `max_steps`: Alternative to epochs, sets maximum number of training steps
   - More epochs allow better learning but risk overfitting

2. **Batch Size Parameters**:
   - `per_device_train_batch_size`: Determines memory usage and training stability
   - `gradient_accumulation_steps`: Enables larger effective batch sizes
   - Larger batches provide more stable gradients but require more memory

3. **Learning Rate Parameters**:
   - `learning_rate`: Controls size of weight updates
   - `warmup_ratio`: Portion of training used for learning rate warmup
   - A learning rate that is too high can cause instability; one that is too low results in slow learning

4. **Monitoring Parameters**:
   - `logging_steps`: Frequency of metric logging
   - `eval_steps`: How often to evaluate on validation data
   - `save_steps`: Frequency of model checkpoint saves

<Tip>
Start with conservative values and adjust based on monitoring:
- Begin with 1-3 epochs
- Use smaller batch sizes initially
- Monitor validation metrics closely
- Adjust learning rate if training is unstable
</Tip>
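
To see how the batch size and duration parameters interact, the following sketch derives the effective batch size and approximate step counts from them. The `training_schedule` helper and the example numbers are hypothetical, and the rounding here is only an approximation of what the trainer computes internally.

```python
import math


def training_schedule(num_examples, per_device_batch_size, grad_accum_steps,
                      num_devices=1, num_train_epochs=1, warmup_ratio=0.0):
    """Approximate step counts implied by the batch and duration settings."""
    # One optimizer update sees this many examples.
    effective_batch = per_device_batch_size * grad_accum_steps * num_devices
    steps_per_epoch = math.ceil(num_examples / effective_batch)
    total_steps = steps_per_epoch * num_train_epochs
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    return {
        "effective_batch_size": effective_batch,
        "steps_per_epoch": steps_per_epoch,
        "total_steps": total_steps,
        "warmup_steps": warmup_steps,
    }


# Hypothetical run: 10,000 examples, batch size 4, 8 accumulation steps, 3 epochs.
schedule = training_schedule(
    num_examples=10_000,
    per_device_batch_size=4,
    grad_accum_steps=8,
    num_train_epochs=3,
    warmup_ratio=0.1,
)
print(schedule)
```

With these numbers, each update uses 32 examples, so raising `gradient_accumulation_steps` lowers the number of optimizer steps per epoch without changing per-device memory use.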

## Implementation with TRL

Now that we understand the key components, let's implement the training with proper validation and monitoring. We will use the `SFTTrainer` class from the Transformer Reinforcement Learning (TRL) library, which is built on top of the `transformers` library. Here's a complete example:

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer, setup_chat_format
import torch

# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load dataset
dataset = load_dataset("HuggingFaceTB/smoltalk", "all")

# Configure model and tokenizer
model_name = "HuggingFaceTB/SmolLM2-135M"
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=model_name).to(
    device
)
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)
# Setup chat template
model, tokenizer = setup_chat_format(model=model, tokenizer=tokenizer)

# Configure trainer
training_args = SFTConfig(
    output_dir="./sft_output",
    max_steps=1000,
    per_device_train_batch_size=4,
    learning_rate=5e-5,
    logging_steps=10,
    save_steps=100,
    eval_strategy="steps",
    eval_steps=50,
)

# Initialize trainer
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    processing_class=tokenizer,
)

# Start training
trainer.train()
```

<Tip>
When using a dataset with a "messages" field (like the example above), the SFTTrainer automatically applies the model's chat template, which it retrieves from the hub. This means you don't need any additional configuration to handle chat-style conversations - the trainer will format the messages according to the model's expected template format.
</Tip>
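
To make the idea of a chat template concrete, here is a simplified, pure-Python stand-in for what applying a ChatML-style template produces. It assumes the ChatML layout with `<|im_start|>` and `<|im_end|>` markers. Real templates are Jinja templates stored with the tokenizer and applied through `tokenizer.apply_chat_template`, so details may differ from this hypothetical `to_chatml` helper.

```python
def to_chatml(messages, add_generation_prompt=False):
    """Render {"role", "content"} messages in a ChatML-style layout (sketch)."""
    text = ""
    for message in messages:
        # Each turn is wrapped in start/end markers, with the role on its own line.
        text += f"<|im_start|>{message['role']}\n{message['content']}<|im_end|>\n"
    if add_generation_prompt:
        # At inference time, this cues the model to begin the assistant turn.
        text += "<|im_start|>assistant\n"
    return text


conversation = [
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi, how can I help?"},
]
rendered = to_chatml(conversation)
print(rendered)
```

During SFT, the model is trained on text in this rendered form, which is how it learns to produce responses that follow the template.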

## Packing the Dataset

The SFTTrainer supports example packing to improve training efficiency. Packing places multiple short examples into the same input sequence, which reduces padding and makes better use of the GPU during training. To enable it, set `packing=True` in the `SFTConfig` constructor. When using packed datasets with `max_steps`, be aware that you may train for more epochs than expected, depending on your packing configuration.

You can customize how examples are combined using a formatting function, which is particularly useful for datasets with multiple fields such as question-answer pairs. For evaluation datasets, you can disable packing by setting `eval_packing=False` in the `SFTConfig`. Here's a basic example of enabling packing:

```python
# Configure packing
training_args = SFTConfig(packing=True)

trainer = SFTTrainer(model=model, train_dataset=dataset, args=training_args)

trainer.train()
```
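
The idea behind packing can be illustrated with a small concatenate-and-chunk sketch. This is not TRL's actual implementation, which may use a different strategy depending on the version. The `pack_sequences` helper, the token IDs, and the `eos_token_id` value are all hypothetical.

```python
def pack_sequences(tokenized_examples, max_length, eos_token_id):
    """Concatenate examples (each followed by EOS) and split into fixed-size blocks."""
    stream = []
    for ids in tokenized_examples:
        stream.extend(ids)
        stream.append(eos_token_id)  # mark the boundary between examples

    # Drop the trailing remainder that does not fill a whole block.
    usable = len(stream) - len(stream) % max_length
    return [stream[i:i + max_length] for i in range(0, usable, max_length)]


examples = [[1, 2, 3], [4, 5], [6, 7, 8, 9], [10]]
packed = pack_sequences(examples, max_length=4, eos_token_id=0)
print(packed)  # [[1, 2, 3, 0], [4, 5, 0, 6], [7, 8, 9, 0]]
```

Here four examples collapse into three packed sequences with no padding. Because the packed dataset has fewer rows than the original, a fixed `max_steps` can cover more epochs than expected.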

When packing a dataset with multiple fields, you can define a custom formatting function that combines the fields into a single input sequence. In the example below, the function takes one example and returns the formatted text as a string:

```python
def formatting_func(example):
    text = f"### Question: {example['question']}\n ### Answer: {example['answer']}"
    return text


training_args = SFTConfig(packing=True)
trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset,
    args=training_args,
    formatting_func=formatting_func,
)
```

## Monitoring Training Progress

Effective monitoring is crucial for successful fine-tuning. Let's explore what to watch for during training:

### Understanding Loss Patterns

Training loss typically follows three distinct phases:
1. Initial Sharp Drop: Rapid adaptation to new data distribution
2. Gradual Stabilization: The rate of improvement slows as the model fine-tunes
3. Convergence: Loss values stabilize, indicating training completion

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/nlp_course_sft_loss_graphic.png" alt="SFTTrainer Training" />

### Metrics to Monitor

Effective monitoring combines tracking quantitative metrics with qualitatively evaluating the model's outputs. The quantitative metrics available are:

- Training loss
- Validation loss
- Learning rate progression
- Gradient norms

<Tip warning={true}>
Watch for these warning signs during training:
1. Validation loss increasing while training loss decreases (overfitting)
2. No significant improvement in loss values (underfitting)
3. Extremely low loss values (potential memorization)
4. Inconsistent output formatting (template learning issues)
</Tip>

### The Path to Convergence

As training progresses, the loss curve should gradually stabilize. The key indicator of healthy training is a small gap between training and validation loss, suggesting the model is learning generalizable patterns rather than memorizing specific examples. The absolute loss values will vary depending on your task and dataset.
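
As a rough illustration, the check below compares recent training and validation losses and flags the divergence pattern associated with overfitting. It assumes both losses were logged at the same steps. The `diagnose_losses` helper, the window size, and the loss values are hypothetical and would need tuning for a real run.

```python
def diagnose_losses(train_losses, eval_losses, window=3):
    """Classify the recent trend of two loss histories (heuristic sketch)."""
    if len(train_losses) <= window or len(eval_losses) <= window:
        return "not enough data"

    # Compare the latest value with the value `window` evaluations earlier.
    train_falling = train_losses[-1] < train_losses[-1 - window]
    eval_rising = eval_losses[-1] > eval_losses[-1 - window]

    if train_falling and eval_rising:
        return "possible overfitting"
    if not train_falling:
        return "plateau"
    return "healthy"


train = [2.0, 1.5, 1.2, 1.0, 0.8, 0.6]
eval_ok = [2.1, 1.7, 1.5, 1.4, 1.3, 1.25]
eval_bad = [2.1, 1.7, 1.5, 1.55, 1.6, 1.7]

print(diagnose_losses(train, eval_ok))   # healthy
print(diagnose_losses(train, eval_bad))  # possible overfitting
print(round(eval_bad[-1] - train[-1], 2))  # current train/validation gap
```

A heuristic like this only supplements looking at the curves directly, since noisy losses can trigger false alarms with a small window.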

### Monitoring Training Progress

The loss graph shown earlier under "Understanding Loss Patterns" illustrates a typical training progression. Notice how both training and validation loss decrease sharply at first, then gradually level off. This pattern indicates the model is learning effectively while maintaining generalization ability.

### Warning Signs to Watch For

Several patterns in the loss curves can indicate potential issues. Below we illustrate common warning signs and solutions that we can consider.

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/sft_loss_1.png" alt="SFTTrainer Training" />

If the validation loss decreases at a significantly slower rate than training loss, your model is likely overfitting to the training data. Consider:
- Reducing the training steps
- Increasing the dataset size
- Validating dataset quality and diversity

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/sft_loss_2.png" alt="SFTTrainer Training" />

If the loss doesn't show significant improvement, the model might be:
- Learning too slowly (try increasing the learning rate)
- Struggling with the task (check data quality and task complexity)
- Hitting architecture limitations (consider a different model)

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/sft_loss_3.png" alt="SFTTrainer Training" />

Extremely low loss values could suggest memorization rather than learning. This is particularly concerning if:
- The model performs poorly on new, similar examples
- The outputs lack diversity
- The responses are too similar to training examples

<Tip warning={true}>
Monitor both the loss values and the model's actual outputs during training. Sometimes the loss can look good while the model develops unwanted behaviors. Regular qualitative evaluation of the model's responses helps catch issues that metrics alone might miss.
</Tip>

Note that the interpretation of loss values outlined here covers the most common case. In practice, loss values can behave in various ways depending on the model, the dataset, the training parameters, and other factors. If you are interested in exploring these patterns further, check out this blog post by the people at [Fast AI](https://www.fast.ai/posts/2023-09-04-learning-jumps/).

## Evaluation after SFT

In section [11.4](/en/chapter11/4) we will learn how to evaluate the model using benchmark datasets. For now, we will focus on the qualitative evaluation of the model.

After completing SFT, consider these follow-up actions:

1. Evaluate the model thoroughly on held-out test data
2. Validate template adherence across various inputs
3. Test domain-specific knowledge retention
4. Monitor real-world performance metrics
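
Template adherence can be measured automatically when the expected format is machine-checkable. The sketch below assumes the fine-tuned model is supposed to return JSON objects with a fixed set of keys. The schema, the `adherence_rate` helper, and the sample outputs are hypothetical and stand in for generations collected from your own model.

```python
import json

REQUIRED_KEYS = {"answer", "confidence"}  # hypothetical output schema


def adherence_rate(outputs, required_keys=REQUIRED_KEYS):
    """Fraction of outputs that parse as JSON objects containing all required keys."""
    if not outputs:
        return 0.0
    valid = 0
    for text in outputs:
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError:
            continue  # not valid JSON at all
        if isinstance(parsed, dict) and required_keys.issubset(parsed):
            valid += 1
    return valid / len(outputs)


sample_outputs = [
    '{"answer": "42", "confidence": 0.9}',
    '{"answer": "42"}',                      # missing a required key
    "The answer is 42.",                     # free text instead of JSON
    '{"answer": "x", "confidence": 0.1, "extra": 1}',
]
print(adherence_rate(sample_outputs))  # 0.5
```

Tracking this rate across checkpoints gives a simple quantitative signal to pair with manual review of the responses.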

<Tip>
Document your training process, including:
- Dataset characteristics
- Training parameters
- Performance metrics
- Known limitations

This documentation will be valuable for future model iterations.
</Tip>

## Quiz

### 1. What parameters control the training duration in SFT?

<Question
	choices={[
		{
			text: "num_train_epochs and max_steps",
			explain: "Correct! These parameters determine how long the model will train, either by number of epochs or total steps.",
			correct: true
		},
		{
			text: "batch_size and learning_rate",
			explain: "While these affect training, they don't directly control the duration."
		},
		{
			text: "gradient_checkpointing and warmup_ratio",
			explain: "These parameters affect training efficiency and stability, not duration."
		}
	]}
/>

### 2. Which pattern in the loss curves indicates potential overfitting?

<Question
    choices={[
        {
            text: "Validation loss increases while training loss continues to decrease",
            explain: "Correct! This divergence between training and validation loss is a classic sign of overfitting.",
            correct: true
        },
        {
            text: "Both training and validation loss decrease steadily",
            explain: "This pattern actually indicates healthy training."
        },
        {
            text: "Training loss remains constant while validation loss decreases",
            explain: "This would be an unusual pattern and doesn't indicate overfitting."
        }
    ]}
/>

### 3. What is gradient_accumulation_steps used for?

<Question
    choices={[
        {
            text: "To increase effective batch size without using more memory",
            explain: "Correct! It accumulates gradients across multiple forward passes before updating weights.",
            correct: true
        },
        {
            text: "To save checkpoints during training",
            explain: "This is handled by save_steps and save_strategy parameters."
        },
        {
            text: "To control the learning rate schedule",
            explain: "Learning rate scheduling is controlled by learning_rate and warmup_ratio."
        }
    ]}
/>

### 4. What should you monitor during SFT training?

<Question
    choices={[
        {
            text: "Both quantitative metrics and qualitative outputs",
            explain: "Correct! Monitoring both types of metrics helps catch all potential issues.",
            correct: true
        },
        {
            text: "Only the training loss",
            explain: "Training loss alone isn't sufficient to ensure good model behavior."
        },
        {
            text: "Only the model's output quality",
            explain: "While important, qualitative evaluation alone misses important training dynamics."
        }
    ]}
/>

### 5. What indicates healthy convergence during training?

<Question
    choices={[
        {
            text: "A small gap between training and validation loss",
            explain: "Correct! This indicates the model is learning generalizable patterns.",
            correct: true
        },
        {
            text: "Training loss reaching zero",
            explain: "Extremely low loss values might indicate memorization rather than learning."
        },
        {
            text: "Validation loss being lower than training loss",
            explain: "This would be unusual and might indicate problems with the validation set."
        }
    ]}
/>

## 💐 Nice work!

You've learned how to fine-tune models using SFT! To continue your learning:
1. Try the notebook with different parameters
2. Experiment with other datasets
3. Contribute improvements to the course material

## Additional Resources

- [TRL Documentation](https://huggingface.co/docs/trl)
- [SFT Examples Repository](https://github.com/huggingface/trl/blob/main/trl/scripts/sft.py)
- [Fine-tuning Best Practices](https://huggingface.co/docs/transformers/training)
